Amazon Research Tuebingen
Abstract: Recently, sparse autoencoders (SAEs) have emerged as an attractive tool for interpreting and interacting with representations in practical neural networks. It is common empirical folklore that SAEs are highly unstable, and we also show this theoretically: different training runs are likely to produce different concept dictionaries and sparse codes. We characterize the model properties that hinder the stability of real-world SAEs and address each of these problems through minimal changes to the architecture and training procedure. Together, these changes yield two versions of an identifiable SAE (iSAE), a variant of the standard TopK SAE with lower reconstruction error and improved stability. We explain this improvement theoretically by connecting SAEs with traditional dictionary learning approaches, and show that the dictionaries learned in practice satisfy an approximate restricted isometry condition, rendering the corresponding sparse codes in those models near-identifiable.
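For orientation, here is a minimal sketch of the forward pass of a standard TopK SAE, the baseline the iSAE modifies. It assumes a common convention: subtract a decoder bias, encode linearly, keep the k largest pre-activations per input, and decode linearly. The ReLU on the kept values and all names (`topk_sae_forward`, `W_enc`, `b_dec`, ...) are illustrative assumptions. This is not the iSAE architecture.

```python
import numpy as np

def topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k):
    """Sketch of a standard TopK SAE forward pass (not the iSAE variant).

    x: (batch, d) inputs; W_enc: (d, m); W_dec: (m, d); k: active latents per input.
    """
    pre = (x - b_dec) @ W_enc + b_enc                 # (batch, m) pre-activations
    # Indices of the k largest pre-activations in each row (unordered).
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    z = np.zeros_like(pre)
    rows = np.arange(pre.shape[0])[:, None]
    z[rows, idx] = np.maximum(pre[rows, idx], 0.0)    # assumed ReLU on the kept values
    x_hat = z @ W_dec + b_dec                         # linear decoder (the "dictionary")
    return z, x_hat
```

The rows of `W_dec` play the role of the concept dictionary and `z` the sparse code. The instability discussed above means that two training runs can yield different `W_dec` and `z` for the same inputs.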
Abstract: While Vision Foundation Models (VFMs) excel at predictive tasks on satellite imagery, their performance can arise from visual correlations rather than underlying structural invariants, making even perception-based out-of-distribution accuracy a poor proxy for scientific utility. As a result, models may look correct without reasoning correctly, a discrepancy we term the Perception-Physics Paradox. To address this gap, we introduce scientific alignment as an implicit objective for representation learning in scientific domains. We study a principled, testable aspect of scientific alignment through structural isomorphism, which requires latent representations to uniquely identify physical systems up to a linear reparameterization. This perspective induces a hierarchy of necessary conditions and yields a systematic probing protocol for physical and causal interpretability. To operationalize this framework, we release TC-Bench, a global, reproducible benchmark dataset with an automated construction pipeline for tropical cyclone research, and show that current VFMs rely on visual shortcuts that collapse in intense regimes, indicating that scientific alignment does not arise as a natural byproduct of scaling alone.
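Identification "up to a linear reparameterization" suggests a simple necessary check: a linear probe from latents to physical variables should explain nearly all of their variance. The sketch below is a generic least-squares probe with an intercept that reports per-variable R². The function name and the synthetic data are hypothetical, and this is not the TC-Bench probing protocol.

```python
import numpy as np

def linear_probe_r2(latents, targets):
    """Fit targets ~ latents @ A + c by least squares and return per-target R^2.

    latents: (n, d_latent) model representations; targets: (n, d_phys) physical variables.
    """
    Z = np.hstack([latents, np.ones((latents.shape[0], 1))])   # append intercept column
    coef, *_ = np.linalg.lstsq(Z, targets, rcond=None)
    resid = targets - Z @ coef
    ss_res = (resid ** 2).sum(axis=0)
    ss_tot = ((targets - targets.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot
```

Latents that are an (almost noise-free) linear image of the physical variables give R² close to 1. Latents unrelated to them give R² near 0. A high R² is only necessary, not sufficient, for the structural isomorphism described above.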
Abstract: Sparse Autoencoders (SAEs) that accurately reconstruct their input (minimizing distortion) by making efficient use of few features (minimizing the rate) often fail to learn monosemantic (highly interpretable) representations, limiting their usefulness for mechanistic interpretability. In this paper, we characterise this tension in learning faithful, efficient, and interpretable explanations, introducing the Rate-Distortion-Polysemanticity tradeoff in SAEs. Under toy-modeling assumptions, we theoretically and empirically show that restricting the SAE to be monosemantic necessarily comes with an increase in rate and distortion. Assuming a generative model behind the input observations, we further demonstrate that the degree of polysemanticity of optimal SAEs is determined by the training data distribution, especially by the probability that features co-occur. Finally, we extend the analysis to real-world settings by deriving necessary conditions that a polysemanticity measure should satisfy when the data-generating process is unknown, and we benchmark existing proxy metrics on SAEs trained on Large Language Models. Taken together, our findings show that polysemanticity is a data problem, which should be accounted for when addressing it at the architectural and optimization level.
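The quantities named above can be measured empirically in simple ways. The sketch below assumes that "rate" is the average number of active latents per input (L0) and "distortion" is the mean squared reconstruction error. These are common SAE proxies and may differ from the paper's exact definitions. It also computes the empirical pairwise co-occurrence probabilities of binary feature activations. All function names are hypothetical.

```python
import numpy as np

def rate_and_distortion(codes, x, x_hat):
    """Assumed proxies: rate = mean L0 of the codes, distortion = mean squared error per input."""
    rate = np.count_nonzero(codes, axis=1).mean()
    distortion = np.mean(np.sum((x - x_hat) ** 2, axis=1))
    return rate, distortion

def cooccurrence_probabilities(active):
    """active: (n_samples, n_features) boolean matrix of which ground-truth features are present.

    Returns P with P[i, j] = empirical probability that features i and j are both active;
    the diagonal holds the marginal activation probabilities.
    """
    a = active.astype(float)
    return (a.T @ a) / a.shape[0]
```

For example, with the four samples `[[1,1,0],[1,0,0],[1,1,1],[0,0,1]]`, features 0 and 1 co-occur in two of four samples, so `P[0, 1]` is 0.5.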
Abstract: Causal discovery, the problem of inferring the direction of causality, is generally ill-posed. We use the language of structural causal models (SCMs) to show that, if the causal relations are acyclic and invariant across multiple environments (e.g., the way minimum wage affects the employment rate is stable across different geographical regions), only two auxiliary environments are sufficient to infer the causal graph for arbitrary nonlinear mechanisms. Moreover, we demonstrate that this implies identifiability of the SCM's functional mechanisms: as a corollary, we show that two auxiliary environments are sufficient to guarantee correct counterfactual inference. We empirically support our theoretical results on synthetic data.
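To illustrate why invariance across environments carries directional information, here is a toy linear-Gaussian SCM X → Y. The mechanism Y = 2X + noise is fixed, while the scale of the cause X changes across a base environment and two auxiliary ones. Regression slopes in the causal direction stay constant, while slopes in the anticausal direction drift. This is a hypothetical illustration of the invariance principle only. The paper's result covers arbitrary nonlinear mechanisms and is not this procedure.

```python
import numpy as np

def slopes_per_environment(sigmas, n=20000, seed=0):
    """Toy SCM X -> Y with the invariant mechanism Y = 2 X + N(0, 1).

    Only the scale `sigma` of the cause X differs between environments.
    Returns least-squares slopes of Y on X (causal) and X on Y (anticausal).
    """
    rng = np.random.default_rng(seed)
    causal, anticausal = [], []
    for s in sigmas:
        x = s * rng.standard_normal(n)
        y = 2.0 * x + rng.standard_normal(n)      # same mechanism in every environment
        c = np.cov(x, y)
        causal.append(c[0, 1] / c[0, 0])          # slope of Y regressed on X
        anticausal.append(c[0, 1] / c[1, 1])      # slope of X regressed on Y
    return np.array(causal), np.array(anticausal)
```

With `sigmas=[0.5, 1.0, 3.0]`, the causal slopes all stay close to 2. The anticausal slopes follow 2σ²/(4σ² + 1) and move from about 0.25 to about 0.49.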
Abstract: Selection bias is pervasive in observational studies. For example, data from large-scale biobanks can exhibit "healthy volunteer bias", where respondents are healthier and of higher socio-economic status than the population they are meant to represent. Recovering causal effects from such a subpopulation is an important problem in causal inference, as estimating the average treatment effect (ATE) from a selected population can yield a severely biased estimate of the ATE in the whole population. In this paper, we investigate the identifiability of the ATE under selection bias. We provide necessary and sufficient conditions for ATE identifiability, leveraging weak assumptions on probability classes to characterize the propensity score and the selection probability. Compared to previous work, our results extend existing graphical identifiability criteria and offer a more comprehensive understanding of causal effect identification under strictly weaker conditions in the presence of selection bias.
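A minimal simulation shows how selection that depends on a covariate biases the ATE. The treatment effect varies with X, and units with X = 1 are more likely to be selected. A standard Hajek inverse-probability-weighted estimator in the selected sample recovers the selected-population ATE. Additionally weighting by 1/P(S=1 | X) recovers the full-population ATE. This sketch assumes selection depends only on X and that both the propensity score and the selection probability are known. It is a textbook weighting scheme, not the paper's identifiability conditions, and all numbers are hypothetical.

```python
import numpy as np

def simulate(n=200_000, seed=0):
    """Hypothetical population; only units with S = 1 are returned (observed)."""
    rng = np.random.default_rng(seed)
    x = rng.binomial(1, 0.5, n)                      # covariate, e.g. health status
    e = 0.3 + 0.4 * x                                # propensity P(T=1 | X)
    t = rng.binomial(1, e)
    y = t * (1.0 + 2.0 * x) + x + rng.standard_normal(n)   # effect is 1 + 2X
    s_prob = 0.2 + 0.6 * x                           # selection probability P(S=1 | X)
    s = rng.binomial(1, s_prob).astype(bool)
    return t[s], y[s], e[s], s_prob[s]

def weighted_ate(t, y, e, s_prob=None):
    """Hajek IPW estimate of the ATE, optionally reweighted by inverse selection probability."""
    w_sel = np.ones_like(y) if s_prob is None else 1.0 / s_prob
    w1 = w_sel * t / e
    w0 = w_sel * (1 - t) / (1 - e)
    return (w1 @ y) / w1.sum() - (w0 @ y) / w0.sum()
```

Here the population ATE is 1 + 2·0.5 = 2. Among selected units, P(X=1 | S=1) = 0.8, so ignoring selection targets 1 + 2·0.8 = 2.6 instead.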
Abstract: Diffusion-based models on continuous spaces have seen substantial recent progress through the mathematical framework of gradient flows, which leverages the Wasserstein-2 ($W_2$) metric via the Jordan-Kinderlehrer-Otto (JKO) scheme. Despite the increasing popularity of diffusion models on discrete spaces based on continuous-time Markov chains, a parallel theoretical framework based on gradient flows has remained elusive, owing to intrinsic challenges in translating the $W_2$ distance directly to these settings. In this work, we propose the first computational approach addressing these challenges. It leverages an appropriate metric $W_K$ on the simplex of probability distributions, which lets us interpret widely used discrete diffusion paths, such as the discrete heat equation, as gradient flows of specific free-energy functionals. Building on this theoretical insight, we introduce a novel methodology for learning diffusion dynamics over discrete spaces that recovers the underlying functional directly by leveraging first-order optimality conditions of the JKO scheme. The resulting method optimizes a simple quadratic loss, trains extremely fast, does not require individual sample trajectories, and only needs a numerical preprocessing step that computes $W_K$-geodesics. We validate our method through extensive numerical experiments on synthetic data, showing that we can recover the underlying functional for a variety of graph classes.
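As a concrete example of a discrete diffusion path, the sketch below simulates the discrete heat equation dp/dt = -L p on a cycle graph, where L = D - A is the graph Laplacian. It uses explicit Euler steps. Because the columns of L sum to zero, total probability mass is conserved, and the distribution spreads toward uniform. This shows only the forward dynamics named in the abstract. It does not implement the $W_K$ metric or the JKO-based learning method, and the choice of graph and step size is an assumption.

```python
import numpy as np

def cycle_laplacian(n):
    """Combinatorial Laplacian L = D - A of the cycle graph on n nodes."""
    A = np.zeros((n, n))
    for i in range(n):
        A[i, (i + 1) % n] = A[(i + 1) % n, i] = 1.0
    return np.diag(A.sum(axis=1)) - A

def heat_flow(p0, L, dt=0.1, steps=2000):
    """Explicit Euler integration of dp/dt = -L p (stable here since dt * max eigenvalue < 2)."""
    p = p0.copy()
    for _ in range(steps):
        p = p - dt * (L @ p)
    return p
```

Starting from a point mass on node 0 of an 8-node cycle, `heat_flow` converges to the uniform distribution with each entry equal to 1/8.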
Abstract: Representation learning models exhibit a surprising stability in their internal representations. Whereas most prior work treats this stability as a single property, we formalize it as two distinct concepts: statistical identifiability (consistency of representations across runs) and structural identifiability (alignment of representations with some unobserved ground truth). Recognizing that perfect pointwise identifiability is generally unrealistic for modern representation learning models, we propose new model-agnostic definitions of statistical and structural near-identifiability of representations up to some error tolerance $\varepsilon$. Leveraging these definitions, we prove a statistical $\varepsilon$-near-identifiability result for the representations of models with nonlinear decoders, generalizing existing identifiability theory beyond last-layer representations in, e.g., generative pre-trained transformers (GPTs) to near-identifiability of the intermediate representations of a broad class of models, including (masked) autoencoders (MAEs) and supervised learners. Although these weaker assumptions confer weaker identifiability, we show that independent component analysis (ICA) can resolve much of the remaining linear ambiguity for this class of models, and we validate and measure our near-identifiability claims empirically. With additional assumptions on the data-generating process, statistical identifiability extends to structural identifiability, yielding a simple and practical recipe for disentanglement: ICA post-processing of latent representations. On synthetic benchmarks, this approach achieves state-of-the-art disentanglement using a vanilla autoencoder. With a foundation-model-scale MAE for cell microscopy, it disentangles biological variation from technical batch effects, substantially improving downstream generalization.
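To make "ICA post-processing of latent representations" concrete, here is a minimal NumPy implementation of symmetric FastICA with a tanh nonlinearity. It is applied to a synthetic representation that is an unknown invertible linear mix of independent, non-Gaussian (Laplace) sources. That is the kind of residual linear ambiguity the abstract describes. The data-generating assumptions and function names are hypothetical, and this is not the authors' pipeline. In practice a library implementation such as scikit-learn's `FastICA` would typically be used.

```python
import numpy as np

def _sym_decorrelate(W):
    """Return (W W^T)^{-1/2} W so that the rows of W are orthonormal."""
    s, u = np.linalg.eigh(W @ W.T)
    return u @ np.diag(1.0 / np.sqrt(s)) @ u.T @ W

def fastica(X, n_iter=200, seed=0):
    """Symmetric FastICA (tanh contrast). X: (n_samples, n_dims). Returns estimated sources."""
    X = X - X.mean(axis=0)
    d, E = np.linalg.eigh(np.cov(X, rowvar=False))
    Z = X @ E @ np.diag(1.0 / np.sqrt(d)) @ E.T              # whitened data, identity covariance
    rng = np.random.default_rng(seed)
    W = _sym_decorrelate(rng.standard_normal((Z.shape[1], Z.shape[1])))
    for _ in range(n_iter):
        G = np.tanh(Z @ W.T)                                  # g(w_i^T z) for every unit i
        # Fixed-point update w <- E[z g(w^T z)] - E[g'(w^T z)] w, then decorrelate the rows.
        W = _sym_decorrelate(G.T @ Z / len(Z) - np.diag((1 - G ** 2).mean(axis=0)) @ W)
    return Z @ W.T
```

The recovered components match the true sources only up to permutation, sign, and scale. Recovery is therefore checked through the absolute correlations between true and recovered components.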
Abstract: Mendelian Randomization (MR) is a prominent observational epidemiological research method designed to address unobserved confounding when estimating causal effects. However, its core assumptions, particularly the independence between instruments and unobserved confounders, are often violated due to population stratification or assortative mating. Leveraging the increasing availability of multi-environment data, we propose a representation learning framework that exploits cross-environment invariance to recover latent exogenous components of genetic instruments. We provide theoretical guarantees for identifying these latent instruments under various mixing mechanisms and demonstrate the effectiveness of our approach through simulations and semi-synthetic experiments using data from the All of Us Research Hub.
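The sketch below simulates the setting MR relies on and what happens when the instrument-confounder independence fails. It uses the classic Wald-ratio instrumental-variable estimator cov(G, Y) / cov(G, X). The genotype G shifts the exposure X, an unobserved confounder U affects both X and Y, and the true effect is 1. A hypothetical `leak` parameter makes U depend on G, mimicking population stratification. All names and coefficients are assumptions. This is the standard estimator, not the paper's representation learning method.

```python
import numpy as np

def wald_ratio(g, x, y):
    """IV (Wald ratio) estimate of the effect of X on Y with instrument G."""
    return np.cov(g, y)[0, 1] / np.cov(g, x)[0, 1]

def simulate_mr(n=100_000, leak=0.0, seed=0):
    """Hypothetical MR data with true causal effect 1.0 of X on Y.

    `leak` > 0 makes the confounder U depend on the genotype G (a violated IV assumption).
    """
    rng = np.random.default_rng(seed)
    g = rng.binomial(2, 0.3, n).astype(float)   # allele count (0, 1 or 2)
    u = leak * g + rng.standard_normal(n)       # unobserved confounder
    x = 0.5 * g + u + rng.standard_normal(n)    # exposure
    y = 1.0 * x + u + rng.standard_normal(n)    # outcome
    return g, x, y
```

With `leak=0`, the Wald ratio is close to 1, while naive OLS of Y on X is biased upward (about 1.48 here). With `leak=0.3`, the Wald ratio itself becomes biased, to about 1 + 0.3/0.8 ≈ 1.375.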
Abstract: Pretraining and fine-tuning are central stages in modern machine learning systems. In practice, feature learning plays an important role across both stages: deep neural networks learn a broad range of useful features during pretraining and further refine those features during fine-tuning. However, an end-to-end theoretical understanding of how choices of initialization impact the ability to reuse and refine features during fine-tuning has remained elusive. Here we develop an analytical theory of the pretraining-fine-tuning pipeline in diagonal linear networks, deriving exact expressions for the generalization error as a function of initialization parameters and task statistics. We find that different initialization choices place the network into four distinct fine-tuning regimes that are distinguished by their ability to support feature learning and reuse, and therefore by the task statistics for which they are beneficial. In particular, a smaller initialization scale in earlier layers enables the network to both reuse and refine its features, leading to superior generalization on fine-tuning tasks that rely on a subset of pretraining features. We demonstrate empirically that the same initialization parameters impact generalization in nonlinear networks trained on CIFAR-100. Overall, our results demonstrate analytically how data and network initialization interact to shape fine-tuning generalization, highlighting an important role for the relative scale of initialization across different layers in enabling continued feature learning during fine-tuning.
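For readers unfamiliar with diagonal linear networks, here is a minimal sketch of the model and its gradient-descent training with separate initialization scales per layer. The model is f(x) = Σᵢ vᵢ uᵢ xᵢ, where u is assumed to be the first (earlier) layer, v the second, and β = u ⊙ v the effective linear predictor. The squared loss, learning rate, and function names are assumptions. The sketch makes no claim about which of the paper's four regimes a given initialization falls into.

```python
import numpy as np

def train_diagonal_linear_network(X, y, init_u, init_v, lr=0.05, steps=5000):
    """Gradient descent on f(x) = sum_i v_i * u_i * x_i with loss 0.5 * mean squared error.

    init_u / init_v: initialization scales of the first (u) and second (v) layer.
    Returns the effective weights u * v and the loss history.
    """
    n, d = X.shape
    u = np.full(d, float(init_u))
    v = np.full(d, float(init_v))
    losses = []
    for _ in range(steps):
        r = X @ (u * v) - y                       # residuals
        losses.append(0.5 * np.mean(r ** 2))
        grad_beta = X.T @ r / n                   # gradient w.r.t. the effective weights
        # Chain rule through beta = u * v; both layers are updated simultaneously.
        u, v = u - lr * grad_beta * v, v - lr * grad_beta * u
    return u * v, np.array(losses)
```

To explore the effect of relative layer scales, one could call it, for example, with a small `init_u` and a larger `init_v`, and then with the reverse. Under gradient flow, uᵢ² − vᵢ² is conserved, which is one reason the relative scale at initialization persists through training.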
Abstract: Language and vision-language models have shown impressive performance across a wide range of tasks, but their internal mechanisms remain only partly understood. In this work, we study how individual attention heads in text-generative models specialize in specific semantic or visual attributes. Building on an established interpretability method, we reinterpret the practice of probing intermediate activations with the final decoding layer through the lens of signal processing. This lets us analyze multiple samples in a principled way and rank attention heads based on their relevance to target concepts. Our results show consistent patterns of specialization at the head level across both unimodal and multimodal transformers. Remarkably, we find that editing as few as 1% of the heads, selected using our method, can reliably suppress or enhance targeted concepts in the model output. We validate our approach on language tasks such as question answering and toxicity mitigation, as well as vision-language tasks including image classification and captioning. Our findings highlight an interpretable and controllable structure within attention layers, offering simple tools for understanding and editing large-scale generative models.
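The underlying practice of probing intermediate activations with the final decoding layer can be sketched as follows. Each head's output is projected through the unembedding matrix, the softmax probability mass on a set of concept tokens is measured, that score is averaged over samples, and the heads are ranked. This is a plain logit-lens-style baseline, not the authors' signal-processing reformulation. The shapes and names (`head_outputs`, `W_U`, `concept_ids`) are hypothetical, and extracting per-head outputs from a real model is assumed to happen elsewhere.

```python
import numpy as np

def rank_heads_by_concept(head_outputs, W_U, concept_ids):
    """Rank heads by the decoded probability mass they place on concept tokens.

    head_outputs: (n_samples, n_heads, d_model) per-head contributions to the residual stream.
    W_U: (d_model, vocab) final decoding (unembedding) matrix.
    concept_ids: token ids expressing the target concept.
    Returns head indices sorted from most to least concept-relevant, and the per-head scores.
    """
    logits = head_outputs @ W_U                              # (n_samples, n_heads, vocab)
    logits = logits - logits.max(axis=-1, keepdims=True)     # numerically stable softmax
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    scores = probs[..., concept_ids].sum(axis=-1).mean(axis=0)   # average over samples
    return np.argsort(-scores), scores
```

In a synthetic check where one head's output is pushed along the concept tokens' unembedding directions, that head ranks first. Heads chosen this way are the natural candidates for the kind of targeted editing described above.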